import random
import unicodedata

import torch
from torch.utils.data import DataLoader, Dataset

from config import *


class Lang:
    """Vocabulary for a corpus: bidirectional word<->index maps plus word counts.

    Indices 0-3 are reserved for the special tokens <unk>, <pad>, <bos>, <eos>.
    """

    def __init__(self, name):
        self.name = name
        # Reserved special tokens occupy indices 0-3.
        self.word2index = {'<unk>': 0, '<pad>': 1, '<bos>': 2, '<eos>': 3}
        self.word2count = {}
        self.index2word = {0: '<unk>', 1: '<pad>', 2: '<bos>', 3: '<eos>'}
        self.n_words = 4  # vocabulary size, including the 4 special tokens

    def addSentence(self, sentenceList):
        """Add every token of an already-tokenized sentence to the vocabulary."""
        for word in sentenceList:
            self.addWord(word)

    def addWord(self, word):
        """Register a word: assign the next free index on first sight, else bump its count."""
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

    def word2idx(self, word):
        """Look up a word's index, falling back to 0 (<unk>) for unknown words.

        Bug fix: the original used dict.setdefault, which *inserted* every
        unseen word into word2index with index 0 — a pure lookup must not
        mutate the vocabulary. dict.get gives the same fallback without
        the side effect.
        """
        return self.word2index.get(word, 0)

    def __len__(self):
        return self.n_words


# Normalize a Unicode string to plain ASCII by decomposing characters and
# dropping combining marks (e.g. "café" -> "cafe"), simplifying tokenization.
def unicodeToAscii(s):
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)


def filterPair(p):
    # Keep a pair only when both the source and the target utterance are
    # shorter than MAX_LENGTH whitespace-delimited tokens.
    src_len = len(p[0].split(' '))
    tgt_len = len(p[1].split(' '))
    return max(src_len, tgt_len) < MAX_LENGTH


def readLangs(file_name='train.txt'):
    """Read a dialog file and split it into parallel (source, target) token lists.

    Each line holds utterances separated by '__eou__'; the first two
    utterances of a line form one (src, tgt) pair. Pairs where either side
    reaches MAX_LENGTH tokens are dropped.

    Args:
        file_name: name of the data file under DATA_ROOT.

    Returns:
        (lang, src, tgt): a fresh Lang plus parallel lists of token lists.
    """
    print("Reading lines...")

    # Read the file and split into lines; `with` closes the handle
    # (the original leaked the open file object).
    with open(DATA_ROOT + file_name, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')

    # Split each line into utterances, lowercase, strip, and ASCII-normalize.
    lines = [[unicodeToAscii(s.lower().strip()) for s in l.split('__eou__') if s] for l in lines]
    src = []
    tgt = []
    for l in lines:
        # Guard against malformed lines with fewer than two utterances,
        # which would make filterPair raise IndexError on l[1].
        if len(l) >= 2 and filterPair(l):
            src.append(l[0].split())
            tgt.append(l[1].split())

    lang = Lang("dialog")
    # If pair order ever needed swapping (e.g. [en, zh] -> [zh, en]),
    # this would be the place to do it.

    return lang, src, tgt


def truncate_pad(line, num_steps, padding_token):
    """Force a token sequence to exactly num_steps items.

    Longer sequences are truncated; shorter ones are right-padded with
    padding_token.
    """
    if len(line) >= num_steps:
        return line[:num_steps]
    padding = [padding_token] * (num_steps - len(line))
    return line + padding


def build_array_nmt(lines, lang, num_steps):
    """Convert tokenized sentences into a padded index tensor plus valid lengths.

    Each sentence is mapped to vocabulary indices, terminated with <eos>
    (index 3), then truncated/padded with <pad> (index 1) to num_steps.

    Returns:
        array: (num_sentences, num_steps) integer tensor of token indices.
        valid_len: per-row count of non-<pad> positions (<eos> included).
    """
    EOS, PAD = 3, 1
    indexed = [[lang.word2idx(tok) for tok in sent] + [EOS] for sent in lines]
    array = torch.tensor([truncate_pad(sent, num_steps, PAD) for sent in indexed])
    # Counting non-pad positions per row yields each sequence's valid length.
    valid_len = (array != PAD).type(torch.int32).sum(1)
    return array, valid_len


class MyDataset(Dataset):
    """Paired source/target sequences encoded as fixed-length index arrays.

    Construction turns token lists into padded index tensors, e.g.
    [how, are, you, ?] -> [8, 9, 10, 11, 3, 1, 1, 1, 1, 1].
    """

    def __init__(self, src, tgt, lang, num_steps):
        self.src_array, self.src_valid_len = build_array_nmt(src, lang, num_steps)
        self.tgt_array, self.tgt_valid_len = build_array_nmt(tgt, lang, num_steps)

    def __getitem__(self, idx):
        """Return (src, src_valid_len, tgt, tgt_valid_len) for one example."""
        return (
            self.src_array[idx],
            self.src_valid_len[idx],
            self.tgt_array[idx],
            self.tgt_valid_len[idx],
        )

    def __len__(self):
        return len(self.src_array)


def prepareData(num_steps, file_name='train.txt'):
    """Build the vocabulary and encoded dataset from a dialog file.

    Args:
        num_steps: fixed sequence length used for truncation/padding.
        file_name: data file under DATA_ROOT.

    Returns:
        (lang, data): the Lang populated from both sides of every pair,
        and the MyDataset built with it.
    """
    lang, src, tgt = readLangs(file_name)
    print("Load %s sentence pairs" % len(src))
    print("Counting words...")
    for (s, t) in zip(src, tgt):
        lang.addSentence(s)
        lang.addSentence(t)
    # Bug fix: random.randint(0, len(src)) includes len(src) itself, so the
    # sample below could index one past the end. randrange excludes the
    # upper bound, making i always a valid index.
    i = random.randrange(len(src))
    print(src[i], tgt[i])
    print("Counted words:")
    print(lang.name, lang.n_words)
    data = MyDataset(src, tgt, lang, num_steps)

    print(data[i])
    return lang, data


def pretest(num_steps, lang, file_name):
    """Encode an evaluation split using an existing vocabulary.

    The Lang returned by readLangs is discarded: indexing uses the caller's
    `lang` (built from the training split) so indices stay consistent.
    """
    _, src, tgt = readLangs(file_name)
    print("Load %s sentence pairs" % len(src))
    return MyDataset(src, tgt, lang, num_steps)


def loaddata(dataset, batch_size, is_train=True):
    """Wrap a dataset in a DataLoader; shuffling is enabled only for training."""
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=is_train)


if __name__ == "__main__":
    # Build the vocabulary and training dataset, then serialize both.
    lang, data = prepareData(num_steps)
    torch.save(lang, DATA_ROOT + "dialog.lang")
    torch.save(data, DATA_ROOT + "dataset")
    # Encode the held-out splits with the training vocabulary and save them.
    for split_file, out_name in (("test.txt", "testdata"), ("valid.txt", "validdata")):
        split_data = pretest(num_steps, lang, split_file)
        torch.save(split_data, DATA_ROOT + out_name)
